---
id: mnist-digit-classification
sidebar_position: 7
title: MNIST Digit Classification
---

import ExampleDiffCodeTabs from '@site/src/components/ExampleDiffCodeTabs';
import MNISTDemoExamples from '@site/src/components/examples/MNISTDemoExamples';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import SurveyLinkButton from '@site/src/components/SurveyLinkButton';

### In this tutorial we will use a model trained on the MNIST dataset of handwritten digits to predict the number that the user draws.

There are several pieces to this tutorial, so please follow each step carefully. If you get lost, completed examples of each step can be found [here](https://github.com/facebookresearch/playtorch/tree/main/examples/mnist-digit-classification/).

If you haven't installed the PyTorch Live CLI yet, please [follow this tutorial](./get-started.mdx) to get started.

## Create a new React Native project

We will start by creating a new React Native project with the PyTorch Live (PTL) template using the CLI. Run the following command:

```shell
npx torchlive-cli init MNISTClassifier
```

Once that is done, let's go into our newly created project and run it!

```shell
cd MNISTClassifier
```

<Tabs
  defaultValue="android"
  values={[
    {label: 'Android', value: 'android'},
    {label: 'iOS (Simulator)', value: 'ios'},
  ]}>
  <TabItem value="android">

  ```shell
  npx torchlive-cli run-android
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/android/first-run.png "Screenshot of app after fresh init with CLI")

  </TabItem>
  <TabItem value="ios">

  ```shell
  npx torchlive-cli run-ios
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/ios/first-run.png "Screenshot of app after fresh init with CLI")

  </TabItem>
</Tabs>

## Adding Basic UI

The aim of this tutorial is to help you become more familiar with PTL core components, so we will not spend time on how to style UI, but rather provide the layout and styles from the start.

Go ahead and start by copying the following code into the file `src/demos/MyDemos.tsx`:

<MNISTDemoExamples.V0CodeBlock title="src/demos/MyDemos.tsx" />

Now you should see UI that looks exactly like the screenshot below.

<Tabs
  defaultValue="android"
  values={[
    {label: 'Android', value: 'android'},
    {label: 'iOS (Simulator)', value: 'ios'},
  ]}>
  <TabItem value="android">

  ```shell
  npx torchlive-cli run-android
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/android/initial-ui.png "Screenshot of UI created with initial code")

  </TabItem>
  <TabItem value="ios">

  ```shell
  npx torchlive-cli run-ios
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/ios/initial-ui.png "Screenshot of UI created with initial code")

  </TabItem>
</Tabs>

Before we add more code, let's take a second to discuss some of what the above code does.

### The PyTorch Live Canvas Component

We'll be using the PTL canvas in this tutorial to let the user draw numbers that we will try to classify.

Just like the name suggests, a canvas is a surface that we can programmatically draw on.

In order to draw things on a canvas, we use what is called the canvas context, the `ctx` state variable in this case.

Note that we haven't used the context to draw anything yet, so our canvas is essentially invisible.

```tsx
...
export default function MNISTDemo() {
  ...
  const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
  ...
      <Canvas
        style={{
          height: canvasSize,
          width: canvasSize,
        }}
        onContext2D={setCtx}
      />
...
```

### The `onLayout` Prop

In our code, we use the `onLayout` prop on the container view to get the dimensions of the screen space we are working with.

Once we have the dimensions of the screen, we find which is smaller between the screen width and height and then we use that to size our canvas.

This makes sure that our canvas is square and fits within the bounds of our screen in both portrait and landscape.


```tsx
...
export default function MNISTDemo() {
  // Get safe area insets to account for notches, etc.
  const insets = useSafeAreaInsets();
  const [canvasSize, setCanvasSize] = useState<number>(0);
  ...
  return (
    <View
      style={styles.container}
      onLayout={event => {
        const {layout} = event.nativeEvent;
        setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
      }}>
...
```

### Results placeholders

Note that for now we just have placeholder text where we will put our model results. Later on, after we run the model, we will update the text there to display the results.

```tsx
...
  <View style={[styles.resultView]} pointerEvents="none">
    <Text style={[styles.label, styles.secondary]}>
      Highest confidence will go here
    </Text>
    <Text style={[styles.label, styles.secondary]}>
      Second highest will go here
    </Text>
  </View>
...
```

## Filling the Canvas

Like we mentioned in the previous section, our canvas is currently completely blank.

Let's change that and make a clear surface for users to draw on.

Here's a short summary of the changes we're introducing:
1. Import `useCallback` and `useEffect` from React.
2. Define a color for our canvas background (`COLOR_CANVAS_BACKGROUND`). We'll use a lighter purple color to distinguish from the rest of the screen.
3. Create a `draw` function that will fill in our background. We create it with `useCallback` so that the function updates every time the context or size of the canvas changes.

    1. Check to make sure context is not null so we have something to draw with.

    2. Set the context's fill style to our canvas background purple (essentially choosing which marker to work with).

    3. Fill in a rectangle that starts at the origin coordinate (0,0) on our canvas (the top left corner) and ends in the bottom right corner of our canvas so it covers the whole thing.

    4. Call the `invalidate` function on our canvas context to let the screen know that we have drawn new things for it to show.

4. Trigger `draw` any time it changes with the `useEffect` block. Remember that `draw` changes every time the canvas context or size changes, so essentially this `useEffect` runs every time the canvas changes.
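
To make the four drawing steps concrete, here is a minimal sketch of the background fill on its own, outside of React. The `BackgroundContext` interface, the `drawBackground` name, and the color value are assumptions made for this sketch; the real canvas context has many more members, and the `useCallback` wrapper is left out.

```typescript
// Hypothetical minimal shape of the drawing context used in this sketch.
// Only the members needed for the background fill are listed.
interface BackgroundContext {
  fillStyle: string;
  fillRect(x: number, y: number, width: number, height: number): void;
  invalidate(): void;
}

// Placeholder color value; the tutorial's actual constant may differ.
const COLOR_CANVAS_BACKGROUND = '#4F25C6';

// Returns true when the background was drawn, false when there was no context.
function drawBackground(
  ctx: BackgroundContext | undefined,
  canvasSize: number,
): boolean {
  // Step 1: without a context there is nothing to draw with.
  if (ctx == null) {
    return false;
  }
  // Step 2: pick the fill color.
  ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
  // Step 3: cover the square canvas, starting at the top left corner (0, 0).
  ctx.fillRect(0, 0, canvasSize, canvasSize);
  // Step 4: tell the screen there is new content to show.
  ctx.invalidate();
  return true;
}
```

In the component, the same logic would live inside the `draw` callback, with the context and the canvas size as its dependencies.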

:::note

The `useCallback` and `useEffect` that we imported as well as the `useState` function we already had imported are examples of **React Hooks**. Hooks allow React function components, like our `MNISTDemo` function component, to remember things.

You'll notice at the end of `useCallback` and `useEffect` we have a list `[]`. This list is the list of "dependencies" for that hook. This just means that the hook will hold onto the value we give it until one of the "dependencies" changes, in which case it will update the value it remembers.

For more information on React Hooks, head over to the [React docs](https://reactjs.org/docs/hooks-intro.html) where you can read or watch explanations.

:::
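
If dependency lists are new to you, the following simplified model may help. It is not how React implements hooks; `createMemo` is a hypothetical helper that only illustrates the idea of holding onto a value until one of the dependencies changes.

```typescript
// Simplified model of a dependency list. NOT React's implementation.
function createMemo<T>() {
  let lastDeps: readonly unknown[] | undefined;
  let lastValue: T | undefined;

  return (factory: () => T, deps: readonly unknown[]): T => {
    const prev = lastDeps;
    // Recompute on the first call, or when any dependency differs
    // from the one remembered at the same position.
    const changed =
      prev === undefined ||
      prev.length !== deps.length ||
      deps.some((dep, i) => dep !== prev[i]);
    if (changed) {
      lastValue = factory();
      lastDeps = deps;
    }
    return lastValue as T;
  };
}
```

Calling the returned function twice with equal dependencies gives back the remembered value; changing any dependency makes it run the factory again. This is why `draw` only becomes a new function when the canvas context or size changes.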

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V1DiffBlock />
  <MNISTDemoExamples.V1CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

Once you run your app, the My Demos screen should now look like this.

<Tabs
  defaultValue="android"
  values={[
    {label: 'Android', value: 'android'},
    {label: 'iOS (Simulator)', value: 'ios'},
  ]}>
  <TabItem value="android">

  ```shell
  npx torchlive-cli run-android
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/android/fill-canvas-bg.png "Screenshot of app after painting canvas light purple")

  </TabItem>
  <TabItem value="ios">

  ```shell
  npx torchlive-cli run-ios
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/ios/fill-canvas-bg.png "Screenshot of app after painting canvas light purple")

  </TabItem>
</Tabs>

That was a lot of new material just to paint our canvas light purple, but it gives us a good foundation for when we draw more on our canvas.

## Drawing with Touch Input

Now that we have a clear area for the user to draw on, let's make it so they can draw!

Let's go over what we will change to make drawing possible:
1. Import `useRef` from React.
2. Define a color for the trail of the user's touch (`COLOR_TRAIL_STROKE`). We'll use white to make it stand out.
3. Define a `TrailPoint` type to keep our data safe, error free, and easy to use.
4. Create a ref to a list of `TrailPoint`s called `trailRef` and set it to an empty list.
5. Keep track of whether the user has finished drawing with the `drawingDone` state variable and initialize it to `false`.
6. Add support for drawing the trail to our draw function:

    1. Create a variable called `trail` and set it to the current value of our `trailRef`. This is purely so we don't have to write `trailRef.current` every time we need the trail.
    2. Check to make sure the trail isn't null.
    3. Draw our background to cover anything previously drawn.
    4. Check to make sure our trail has at least 1 point.
    5. Set the context's `strokeStyle` - you can think of it as picking the marker color we'll draw lines with.
    6. Set the context's line drawing style parameters (`lineWidth`, `lineJoin`, `lineCap`, and `miterLimit`).
    7. Tell the context to start a line at the first point in the trail.
    8. Loop through points of the trail to add them to the line we are drawing.
    9. Tell the context via the `stroke` method to actually draw the line that we constructed.
    10. Use the `invalidate` method to tell the screen we have updates ready to draw.

7. Create functions for handling when a user touches the canvas (`handleStart`, `handleMove`, and `handleEnd`).
8. The `handleStart` function is called when the user first touches the canvas. It is a simple function that does the following:

    1. Set the `drawingDone` variable to `false`.
    2. Reset `trailRef` to an empty list.

9. The `handleMove` function is called each time the device detects that the touch has changed positions since the starting touch.

    1. Get the coordinates of the new touch location and store them in the `position` variable.
    2. If there are already points in the `trail`, only add the new position if it's at least 5 pixels away from the last position (this avoids keeping unnecessary points that slow down the app).
    3. If there are no points in the `trail`, add the new position.
    4. Trigger the `draw` function to display the newly updated `trail`.

10. The `handleEnd` function is called when the user's touch is no longer detected on the screen.

    1. Simply set the `drawingDone` state variable to `true`.

11. Set the `onTouchStart`, `onTouchMove`, and `onTouchEnd` props on our `<Canvas />` component to `handleStart`, `handleMove`, and `handleEnd` respectively.
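
The point filtering in `handleMove` is the least obvious part of the list above, so here is a minimal sketch of it in isolation. The `{x, y}` shape of `TrailPoint`, the `addPosition` name, and the `MIN_DISTANCE` constant are assumptions made for this sketch and may not match the tutorial code exactly.

```typescript
// Assumed shape of a point on the trail.
type TrailPoint = {x: number; y: number};

// Minimum distance, in pixels, between two stored points.
const MIN_DISTANCE = 5;

// Hypothetical helper: appends `position` to `trail` when the trail is
// empty or when the new position is far enough from the last stored point.
// Returns true when the point was added.
function addPosition(trail: TrailPoint[], position: TrailPoint): boolean {
  if (trail.length === 0) {
    trail.push(position);
    return true;
  }
  const last = trail[trail.length - 1];
  const dx = position.x - last.x;
  const dy = position.y - last.y;
  // Compare squared distances to avoid a square root on every touch event.
  if (dx * dx + dy * dy >= MIN_DISTANCE * MIN_DISTANCE) {
    trail.push(position);
    return true;
  }
  return false;
}
```

Because the trail lives in a ref, mutating the array like this does not re-render the component; the handler triggers `draw` itself after updating the trail.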

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V2DiffBlock />
  <MNISTDemoExamples.V2CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

Run this code and you should now be able to draw on the canvas, as shown in the recording below.

![](/img/tutorial-0.1/mnist-digit-classification/android/initial-laggy-drawing.gif "Screen recording of drawing on canvas with laggy screen")

As you will notice, the drawing seems to glitch out at times, especially as the trail gets longer and longer. Let's fix that next.

:::info

#### React Refs

Refs in React are variables that, like state, persist between renders, but they don't cause the component to re-render when they are changed.

You can get or set the value of a ref via the `.current` property.

In our code, we access the trail with `trailRef.current`. We set the trail in our `handleStart` function to an empty list with `trailRef.current = []`.

:::

## Avoiding Excessive Re-rendering

The glitchiness we see in our code as it stands is because we are asking the screen to refresh before it is ready.

Mobile screens typically refresh 60 times per second (though some new phones refresh twice as often). When we display things with React, it takes care of matching our device's refresh rate.

While we are using React to render our `<Canvas />`, what we draw on our canvas we handle ourselves. Lucky for us, there is a simple way to make sure we don't render too often.

To address this, we will make a few updates to our code, mainly in the `draw` function:

1. Create a ref called `animationHandleRef` that can be a `number` or `null` and set it to `null`. We will use this ref to check whether rendering is currently in progress.
2. Use the `animationHandleRef` in the `draw` function to control how often we rerender:

    1. Start the function by checking if the `animationHandleRef` is set to a non-null value. If it is, we want to end early, because we know the device is already working on rendering.
    2. Wrap our code that does drawing in an inline function that we pass to `requestAnimationFrame` and set the `animationHandleRef`'s value to what it returns. (Read more about this function in the note following the code.)
    3. After telling our canvas we are ready for a rerender with `ctx.invalidate()`, clear the `animationHandleRef` by setting its value to null.
    4. Add `animationHandleRef` to the `draw` function's callback dependencies list.
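
The gating logic can be sketched without React or a device. In this minimal sketch, `createThrottledDraw` and `FrameScheduler` are hypothetical names, and the scheduler is passed in as a parameter so that the code also runs where `requestAnimationFrame` is not available.

```typescript
// Signature compatible with `requestAnimationFrame`: takes a callback and
// returns a numeric handle.
type FrameScheduler = (callback: () => void) => number;

// Hypothetical factory: `render` does the actual canvas drawing.
function createThrottledDraw(
  render: () => void,
  schedule: FrameScheduler,
): () => void {
  // Plays the role of `animationHandleRef.current`.
  let animationHandle: number | null = null;

  return function draw(): void {
    // A frame is already pending, so skip this request.
    if (animationHandle != null) {
      return;
    }
    // This assumes `schedule` runs the callback later, not synchronously,
    // which is how `requestAnimationFrame` behaves.
    animationHandle = schedule(() => {
      render();
      // The frame is done: allow the next draw request through.
      animationHandle = null;
    });
  };
}
```

In the app, `schedule` would be `requestAnimationFrame`, and the local `animationHandle` variable corresponds to the value stored in `animationHandleRef`. However many times `draw` is called between two screen refreshes, the drawing code runs at most once per frame.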

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V3DiffBlock />
  <MNISTDemoExamples.V3CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

:::info

#### What does `requestAnimationFrame` do?

`requestAnimationFrame` is a utility function that helps us run code when the screen is ready for the next rerender.

Input: a callback function, which it runs when the screen next refreshes.

Output: a number that functions as an ID for the callback. You can use that number to cancel the callback if you later decide you don't want to run the code. (We don't need that feature in this tutorial.)

:::

Once you have those changes in your code, go ahead and refresh the app and see how much smoother drawing is.

<Tabs
  defaultValue="android"
  values={[
    {label: 'Android', value: 'android'},
    {label: 'iOS (Simulator)', value: 'ios'},
  ]}>
  <TabItem value="android">

  ```shell
  npx torchlive-cli run-android
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/android/smooth-drawing.gif "Screen recording of drawing on canvas with smooth animation")

  </TabItem>
  <TabItem value="ios">

  ```shell
  npx torchlive-cli run-ios
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/ios/smooth-drawing.gif "Screen recording of drawing on canvas with smooth animation")

  </TabItem>
</Tabs>

With silky smooth drawing in place, we are now ready to start working with the MNIST model.

## Running the Model

We'll start by creating a React hook that provides a function for running inference on an input image. We'll follow React hooks naming conventions and call ours `useMNISTModel`.

Let's summarize the changes we're making:
1. Import `Image` and `MobileModel` from `react-native-pytorch-core`.
2. Load the model file with the `require` function and call it `mnistModel`.
3. Create a type called `MNISTResult` with the following properties:
    1. `num` -  a digit from 0 to 9.
    2. `score` - the confidence the model has in the input image being the given `num`.
4. Define a function called `useMNISTModel` that does the following:
    1. Creates an async React callback function called `processImage` that takes an `Image` as a parameter and does the following:
        1. Uses the `MobileModel` API to execute the `mnistModel` we loaded with a set of parameters that tell the model how much of the image to use and what the foreground and background colors are.
        2. Transforms the raw scores into `MNISTResult` objects.
        3. Sorts the results by `score`.
        4. Returns the sorted results.
    2. Returns an object containing the `processImage` function we just created.
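
The transform and sort steps can be sketched on their own. This sketch assumes that the model returns one raw score per digit, with the array index being the digit, and that results are sorted from highest to lowest score; `toSortedResults` is a hypothetical helper name, and the model execution itself is left out.

```typescript
type MNISTResult = {
  num: number; // a digit from 0 to 9
  score: number; // confidence that the input image shows `num`
};

// Hypothetical helper: pairs each raw score with its digit, then sorts
// so that the most confident result comes first.
function toSortedResults(scores: number[]): MNISTResult[] {
  return scores
    .map((score, num) => ({num, score}))
    .sort((a, b) => b.score - a.score);
}
```

With this ordering, the first two entries of the returned list are the ones the placeholder text is waiting for: the highest and second highest confidence.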

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V4DiffBlock />
  <MNISTDemoExamples.V4CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

An even shorter summary: it takes in an `Image` and gives back a list of sorted results.

But we don't have an `Image` yet. We just have a trail on a canvas.

In the next section, we'll learn how to create an `Image` from the contents of our canvas that we can pass to the model.

## Creating an Image from our Canvas

We are going to create another hook called `useMNISTCanvasInference` that uses the hook we just created (`useMNISTModel`).

This hook will take in the `canvasSize` and give us back two things:
1. `result` - a state variable that holds the sorted list of `MNISTResult`s from the model.
2. `classify` - a function that takes in the `canvas` context, extracts an image from it, processes the image, and then updates the `result` state variable.

In our `classify` callback, we use some of the PTL core components, including the newly imported `ImageUtil` object.

The `ImageUtil` object allows us to take the `imageData` we pull from the canvas and turn it into an `Image` that can be used by our model.

You'll also see that we call the `release` function on both our `imageData` and our `image` variables as soon as we are done using them. This is a vital step to make sure we don't run out of memory on images we no longer need.
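
One way to make sure `release` is never forgotten, even when something throws in between, is a `try`/`finally` block. The following is a sketch of that pattern rather than the tutorial's code: `Releasable` and `withRelease` are hypothetical names, and the sketch assumes that `release` returns a promise.

```typescript
// Hypothetical minimal shape for anything that holds native memory.
interface Releasable {
  release(): Promise<void>;
}

// Hypothetical helper: runs `use` with the resource and releases the
// resource afterwards, whether `use` succeeds or throws.
async function withRelease<R extends Releasable, T>(
  resource: R,
  use: (resource: R) => Promise<T>,
): Promise<T> {
  try {
    return await use(resource);
  } finally {
    await resource.release();
  }
}
```

The tutorial code calls `release` directly after the last use of each object, which has the same effect as long as nothing fails in between.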

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V5DiffBlock />
  <MNISTDemoExamples.V5CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

With this second hook, we are ready to run our model with the user-created drawings. Let's hook it up in the next section.

## Running the Model & Displaying Results

While we add a fair number of lines in this section, they are all simple changes.

Let's cut to the summary:
1. Create a type called `NumberLabelSet` so we know what kind of data we have access to about a number.
2. Create a list of `NumberLabelSet` objects and call it `numLabels`.
3. Get the `classify` method and `result` state variable by calling `useMNISTCanvasInference` from within our demo component.
4. Update the `handleEnd` function to check for a canvas context and then trigger the model.
5. Add `classify` as a dependency to the `handleEnd` callback function.
6. Change the text in the results section to reflect the numbers from the model output.

<ExampleDiffCodeTabs>
  <MNISTDemoExamples.V6DiffBlock />
  <MNISTDemoExamples.V6CodeBlock title="src/demos/MyDemos.tsx" />
</ExampleDiffCodeTabs>

When you run the code, you should see it display results properly in the bottom left corner like the screen recording below.

<Tabs
  defaultValue="android"
  values={[
    {label: 'Android', value: 'android'},
    {label: 'iOS (Simulator)', value: 'ios'},
  ]}>
  <TabItem value="android">

  ```shell
  npx torchlive-cli run-android
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/android/model-running.gif "Screen recording of user drawing numbers and model results displaying correct answers")

  </TabItem>
  <TabItem value="ios">

  ```shell
  npx torchlive-cli run-ios
  ```

  ![](/img/tutorial-0.1/mnist-digit-classification/ios/model-running.gif "Screen recording of user drawing numbers and model results displaying correct answers")

  </TabItem>
</Tabs>

And with that we have a working MNIST classifier!

## Give us feedback

<SurveyLinkButton docTitle="MNIST Digit Classification" />
